package dns

import (
	"context"
	"errors"
	"net"
	"reflect"
	"testing"
	"time"

	"github.com/TecharoHQ/anubis/lib/store/memory"
)

// newTestDNS is a helper function to create a new Dns object with an in-memory cache for testing.
func newTestDNS(forwardTTL int, reverseTTL int) *Dns {
	ctx := context.Background()
	memStore := memory.New(ctx)
	cache := NewDNSCache(forwardTTL, reverseTTL, memStore)
	return New(ctx, cache)
}

// mockLookupAddr is a mock implementation of the net.LookupAddr function.
func mockLookupAddr(addr string) ([]string, error) {
	switch addr {
	case "8.8.8.8":
		return []string{"dns.google."}, nil
	case "1.1.1.1":
		return []string{"one.one.one.one."}, nil
	case "208.67.222.222":
		return []string{"resolver1.opendns.com."}, nil
	case "9.9.9.9":
		return nil, &net.DNSError{Err: "no such host", Name: "9.9.9.9", IsNotFound: true}
	case "1.2.3.4":
		return nil, errors.New("unknown error")
	default:
		return nil, &net.DNSError{Err: "no such host", Name: addr, IsNotFound: true}
	}
}

// mockLookupHost is a mock implementation of the net.LookupHost function.
func mockLookupHost(host string) ([]string, error) {
	switch host {
	case "dns.google":
		return []string{"8.8.8.8", "8.8.4.4"}, nil
	case "one.one.one.one":
		return []string{"1.1.1.1", "1.0.0.1"}, nil
	case "resolver1.opendns.com":
		return []string{"208.67.222.222"}, nil
	case "example.com":
		return nil, &net.DNSError{Err: "no such host", Name: "example.com", IsNotFound: true}
	default:
		return nil, &net.DNSError{Err: "no such host", Name: host, IsNotFound: true}
	}
}

func TestMain(m *testing.M) {
	// Before all tests
	originalLookupAddr := DNSLookupAddr
	originalLookupHost := DNSLookupHost

	DNSLookupAddr = mockLookupAddr
	DNSLookupHost = mockLookupHost

	// Run tests
	exitCode := m.Run()

	// After all tests
	DNSLookupAddr = originalLookupAddr
	DNSLookupHost = originalLookupHost

	// Exit
	if exitCode != 0 {
		panic(exitCode)
	}
}

func TestDns_ArpaReverseIP(t *testing.T) {
	d := newTestDNS(0, 0)
	tests := []struct {
		name    string
		ip      string
		want    string
		wantErr bool
	}{
		{"ipv4", "192.0.2.1", "1.2.0.192", false},
		{"ipv6", "2001:db8::1", "1.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2", false},
		{"invalid ip", "invalid", "invalid", true},
		{"ipv4-mapped ipv6", "::ffff:192.0.2.1", "1.2.0.192", false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := d.ArpaReverseIP(tt.ip)
			if (err != nil) != tt.wantErr {
				t.Errorf("ArpaReverseIP() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if got != tt.want {
				t.Errorf("ArpaReverseIP() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestDns_ReverseDNS(t *testing.T) {
	d := newTestDNS(1, 1) // short TTL for testing cache

	// First call - cache miss
	t.Run("cache miss", func(t *testing.T) {
		got, err := d.ReverseDNS("8.8.8.8")
		if err != nil {
			t.Fatalf("ReverseDNS() error = %v", err)
		}
		want := []string{"dns.google"}
		if !reflect.DeepEqual(got, want) {
			t.Errorf("ReverseDNS() = %v, want %v", got, want)
		}
	})

	// Second call - cache hit
	t.Run("cache hit", func(t *testing.T) {
		// Temporarily replace lookup function to ensure cache is used
		originalLookupAddr := DNSLookupAddr
		DNSLookupAddr = func(addr string) ([]string, error) {
			return nil, errors.New("should not be called")
		}
		defer func() { DNSLookupAddr = originalLookupAddr }()

		got, err := d.ReverseDNS("8.8.8.8")
		if err != nil {
			t.Fatalf("ReverseDNS() error = %v", err)
		}
		want := []string{"dns.google"}
		if !reflect.DeepEqual(got, want) {
			t.Errorf("ReverseDNS() = %v, want %v", got, want)
		}
	})

	// Test cache expiration
	t.Run("cache expiration", func(t *testing.T) {
		time.Sleep(2 * time.Second)
		// Now the cache should be expired
		// We expect the mock to be called again
		// To test this we will change the mock to return something different
		originalLookupAddr := DNSLookupAddr
		DNSLookupAddr = func(addr string) ([]string, error) {
			if addr == "8.8.8.8" {
				return []string{"expired.google."}, nil
			}
			return mockLookupAddr(addr)
		}
		defer func() { DNSLookupAddr = originalLookupAddr }()

		got, err := d.ReverseDNS("8.8.8.8")
		if err != nil {
			t.Fatalf("ReverseDNS() error = %v", err)
		}
		want := []string{"expired.google"}
		if !reflect.DeepEqual(got, want) {
			t.Errorf("ReverseDNS() = %v, want %v", got, want)
		}
	})

	// Test not found
	t.Run("not found", func(t *testing.T) {
		got, err := d.ReverseDNS("9.9.9.9")
		if err != nil {
			t.Fatalf("ReverseDNS() error = %v", err)
		}
		if len(got) != 0 {
			t.Errorf("ReverseDNS() = %v, want empty slice", got)
		}
	})
}

func TestDns_LookupHost(t *testing.T) {
	d := newTestDNS(1, 1)

	t.Run("cache miss", func(t *testing.T) {
		got, err := d.LookupHost("dns.google")
		if err != nil {
			t.Fatalf("LookupHost() error = %v", err)
		}
		want := []string{"8.8.8.8", "8.8.4.4"}
		if !reflect.DeepEqual(got, want) {
			t.Errorf("LookupHost() = %v, want %v", got, want)
		}
	})

	t.Run("cache hit", func(t *testing.T) {
		originalLookupHost := DNSLookupHost
		DNSLookupHost = func(host string) ([]string, error) {
			return nil, errors.New("should not be called")
		}
		defer func() { DNSLookupHost = originalLookupHost }()

		got, err := d.LookupHost("dns.google")
		if err != nil {
			t.Fatalf("LookupHost() error = %v", err)
		}
		want := []string{"8.8.8.8", "8.8.4.4"}
		if !reflect.DeepEqual(got, want) {
			t.Errorf("LookupHost() = %v, want %v", got, want)
		}
	})

	t.Run("cache expiration", func(t *testing.T) {
		time.Sleep(2 * time.Second)
		originalLookupHost := DNSLookupHost
		DNSLookupHost = func(host string) ([]string, error) {
			if host == "dns.google" {
				return []string{"9.9.9.9"}, nil
			}
			return mockLookupHost(host)
		}
		defer func() { DNSLookupHost = originalLookupHost }()

		got, err := d.LookupHost("dns.google")
		if err != nil {
			t.Fatalf("LookupHost() error = %v", err)
		}
		want := []string{"9.9.9.9"}
		if !reflect.DeepEqual(got, want) {
			t.Errorf("LookupHost() = %v, want %v", got, want)
		}
	})

	t.Run("not found", func(t *testing.T) {
		got, err := d.LookupHost("example.com")
		if err != nil {
			t.Fatalf("LookupHost() error = %v", err)
		}
		if len(got) != 0 {
			t.Errorf("LookupHost() = %v, want empty slice", got)
		}
	})
}

func TestDns_VerifyFCrDNS(t *testing.T) {
	d := newTestDNS(1, 1)

	// Helper to convert string to *string
	p := func(s string) *string {
		return &s
	}

	tests := []struct {
		name    string
		ip      string
		pattern *string
		want    bool
	}{
		// Cases without pattern
		{"valid no pattern", "8.8.8.8", nil, true},
		{"valid partial no pattern", "1.1.1.1", nil, true},
		{"not found no pattern", "9.9.9.9", nil, true},
		{"unknown error no pattern", "1.2.3.4", nil, false},

		// Cases with pattern
		{"valid match", "8.8.8.8", p(`.*\.google$`), true},
		{"valid no match", "8.8.8.8", p(`\.com$`), false},
		{"not found with pattern", "9.9.9.9", p(".*"), false},
		{"unknown error with pattern", "1.2.3.4", p(".*"), false},
		{"invalid pattern", "8.8.8.8", p(`[`), false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := d.VerifyFCrDNS(tt.ip, tt.pattern); got != tt.want {
				t.Errorf("VerifyFCrDNS() = %v, want %v", got, tt.want)
			}
		})
	}

	t.Run("reverse cache hit", func(t *testing.T) {
		// Prime the cache
		if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
			t.Fatalf("VerifyFCrDNS() priming failed, got %v, want true", got)
		}

		// Now test with a failing lookup to ensure cache is used
		originalLookupAddr := DNSLookupAddr
		DNSLookupAddr = func(addr string) ([]string, error) {
			return nil, errors.New("should not be called")
		}
		defer func() { DNSLookupAddr = originalLookupAddr }()

		if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
			t.Errorf("VerifyFCrDNS() = %v, want true", got)
		}
	})

	t.Run("forward cache hit", func(t *testing.T) {
		// Prime the cache
		if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
			t.Fatalf("VerifyFCrDNS() priming failed, got %v, want true", got)
		}

		// Now test with a failing lookup to ensure cache is used
		originalLookupHost := DNSLookupHost
		DNSLookupHost = func(host string) ([]string, error) {
			return nil, errors.New("should not be called")
		}
		defer func() { DNSLookupHost = originalLookupHost }()

		if got := d.VerifyFCrDNS("8.8.8.8", nil); got != true {
			t.Errorf("VerifyFCrDNS() = %v, want true", got)
		}
	})
}
